package jgo.tools.compiler
package interm
package expr

import types._
import instr._
import instr.TypeConversions._
import codeseq._

private object ConditionalExpr {
  /**
   * Describes where control should go in a particular situation.  Besides
   * explicit jumps, a target may be `Fall`, meaning "continue with the next
   * instruction"; having that option lets code generation avoid emitting
   * redundant pairs such as `Goto(lbl1); Lbl(lbl1)`.
   */
  sealed abstract class Target {
    /**
     * Returns this target unchanged if it is a `Jump`; if it is `Fall`,
     * returns a jump to `end` instead.
     */
    def replaceFall(end: Label): Target
  }
  
  /**
   * Control should transfer to `lbl`.
   */
  case class Jump(lbl: Label) extends Target {
    def replaceFall(end: Label): Target = this
  }
  
  /**
   * Control should simply continue with the instruction that would execute
   * next in the absence of any Branch or Goto instructions.
   */
  case object Fall extends Target {
    def replaceFall(end: Label): Target = Jump(end)
  }
  
  /** Lets a bare `Label` be used wherever a `Target` is expected. */
  implicit def lbl2target(lbl: Label): Target = Jump(lbl)
}
import ConditionalExpr._

/**
 * An expression whose value can determine the flow of control through a
 * function.  Instead of always materializing a boolean on the operand stack,
 * a conditional expression can emit branches directly (see `branch`); the
 * `mk*` helpers in this class use that to build if statements and loops.
 */
sealed abstract class ConditionalExpr extends Expr {
  // A sanity check like the one below would be desirable, but `typeOf` may
  // still be null while this constructor runs: some subclasses (e.g. CompExpr)
  // define `typeOf` as a `val` in their own body, and such vals are only
  // initialized after this superclass body has executed.
  //require(typeOf.underlying == BoolType, "BoolExpr must have underlying BoolType")
  
  /**
   * Produces code that branches to one of two targets based on the truth value
   * of this conditional expression.
   *
   * Either target may be `Fall` (continue with the code that follows), but the
   * implementations in this file reject the case where both are `Fall`.
   *
   * @param trueBr   where control goes if this expression is true
   * @param falseBr  where control goes if this expression is false
   */
  private[expr] def branch(trueBr: Target, falseBr: Target): CodeBuilder
  
  /**
   * Produces code that evaluates this conditional expression and pushes the
   * resulting (unwrapped) boolean onto the operand stack.
   *
   * Built on `branch`: if the condition holds, control jumps to code that
   * pushes `true`; otherwise it falls through to code that pushes `false`
   * and then skips past the `true` case.
   */
  private[expr] def evalUnder = {
    val g   = new LabelGroup
    val tr  = new Label("push true", g)
    val end = new Label("end of push bool", g)
    
    branch(tr, Fall) |+| PushBool(false) |+| Goto(end) |+| Lbl(tr) |+| PushBool(true) |+| Lbl(end)
  }
  
  /**
   * Produces code that evaluates this conditional expression and leaves its
   * value on the operand stack in the representation demanded by `typeOf`:
   * the underlying boolean is produced by `evalUnder`, then wrapped once with
   * `Wrap` for each `WrappedType` layer (looking through type aliases).
   *
   * This duplicates what `EvalFromUnderlyingExpr` would provide.  That trait
   * cannot be mixed in here because it is private while this class is public,
   * and the Scala compiler does not allow a public type to extend a private
   * one ("escapes its defining scope").  This is a language rule, not a JVM
   * restriction.
   */
  private[expr] def eval: CodeBuilder = {
    def evalWrappedIn(t: Type): CodeBuilder = t match {
      case wt: WrappedType => evalWrappedIn(wt.referent) |+| Wrap(wt)
      case ta: TypeAlias   => evalWrappedIn(ta.effective) //or ta.referent
      case _               => evalUnder
    }
    evalWrappedIn(typeOf)
  }
  
  /**
   * Generates code that branches to the specified label if this conditional expression
   * evaluates to true.  If it evaluates to false, control falls through in the usual
   * manner to whatever instruction follows the code produced by this method.
   *
   * @param lbl  the label to jump to when this expression is true
   */
  def branchTo(lbl: Label): CodeBuilder =
    branch(lbl, Fall)
  
  /**
   * Generates code that executes the specified code if this conditional expression
   * evaluates to true, and skips it otherwise.
   *
   * @param ifBranch  code to run when this expression is true
   */
  def mkIf(ifBranch: CodeBuilder): CodeBuilder = {
    val g   = new LabelGroup
    val end = new Label("end if", g)
    val t   = new Label("if branch", g)
    // Control reaches the if-branch by falling through; `t` is never a jump
    // target and appears to be emitted only as a descriptive marker.
    branch(Fall, end) |+| Lbl(t) |+| ifBranch |+| Lbl(end)
  }
  
  /**
   * Generates code that executes the contents of the first parameter if
   * this conditional expression evaluates to true, and those of the second if
   * it evaluates to false.
   *
   * The else-branch is laid out first (reached by falling through), followed by
   * the if-branch (reached by a jump).
   *
   * @param ifBranch    code to run when this expression is true
   * @param elseBranch  code to run when this expression is false
   */
  def mkIfElse(ifBranch: CodeBuilder, elseBranch: CodeBuilder): CodeBuilder = {
    val g   = new LabelGroup
    val end = new Label("end if-else", g)
    val t   = new Label("if branch", g)
    val f   = new Label("else branch", g)
    branch(t, Fall) |+| Lbl(f) |+| elseBranch |+| Goto(end) |+| Lbl(t) |+| ifBranch |+| Lbl(end)
  }
  
  /**
   * Produces code for a while loop whose condition is this conditional expression and
   * whose body is the code specified, embedding the given labels in the loop code as
   * appropriate so that they may be used as the targets of breaks and continues.
   *
   * The generated code jumps straight to the test, so the body runs zero times if
   * the condition is initially false; the test sits at the bottom of the loop and
   * jumps back to the top of the body while the condition holds.
   * 
   * Note: The code generated by `cond.mkFor(body, incr)(brk, cont)` differs from that
   * generated by `cond.mkWhile(body |+| incr)(brk, cont)` in the placement of the `cont`
   * label.  The method `mkFor` places `cont` ''before'' `incr`, while `mkWhile` places
   * it ''after''.  This is a significant distinction: a continue inside a for loop
   * still executes the increment, whereas here it goes directly to the test.
   *
   * @param loopBody  the body of the loop to be created
   * @param brk  a label that this method is to place where a break statement should
   *             cause control to flow, namely, after the end of the loop
   * @param cont  the label that a continue statement will cause a jump to; this
   *              label is to be placed before the test of the loop
   * 
   * @return code for a loop that repeatedly executes the specified loop-body-code
   *         for as long as this conditional expression evaluates to true
   */
  def mkWhile(loopBody: CodeBuilder)(brk: Label, cont: Label): CodeBuilder = {
    val g    = brk.group
    val top  = new Label("top of while", g)
    
    Goto(cont) |+|
    Lbl(top)  |+| loopBody |+|
    Lbl(cont) |+| branch(top, Fall) |+|
    Lbl(brk)
  }
  
  /**
   * Produces code for a for loop whose condition is this conditional expression and
   * whose body and increment are as specified, embedding the given labels in the loop
   * code as appropriate so that they may be used as the targets of breaks and continues.
   *
   * Like `mkWhile`, the generated code jumps straight to the test first, so neither
   * the body nor the increment runs if the condition is initially false.
   * 
   * Note: The code generated by `cond.mkFor(body, incr)(brk, cont)` differs from that
   * generated by `cond.mkWhile(body |+| incr)(brk, cont)` in the placement of the `cont`
   * label.  The method `mkFor` places `cont` ''before'' `incr`, while `mkWhile` places
   * it ''after''.  This is a significant distinction: it is what makes a continue
   * statement in a for loop still execute the increment.
   * 
   * @param loopBody  the body of the loop to be created
   * @param incrCode  the increment code, executed after each pass through the body
   *                  (including passes cut short by a continue) and before the
   *                  condition is tested again
   * @param brk  a label to the place where a break statement should cause control
   *             to flow, namely, after the end of the loop
   * @param cont  the label that a continue statement will cause a jump to; this
   *              label is to be placed before the ''increment'' of the loop
   * 
   * @return code for a loop that repeatedly executes the specified loop-body-code,
   *         followed by the specified increment code, for as long as this conditional
   *         expression evaluates to true
   */
  def mkFor(loopBody: CodeBuilder, incrCode: CodeBuilder)(brk: Label, cont: Label): CodeBuilder = {
    val g    = brk.group
    val top  = new Label("top of for", g)
    val cond = new Label("cond of for", g)
    
    Goto(cond) |+|
    Lbl(top)  |+| loopBody |+|
    Lbl(cont) |+| incrCode |+|
    Lbl(cond) |+| branch(top, Fall) |+|
    Lbl(brk)
  }
}

/**
 * A conditional expression backed by code that pushes a boolean value onto the
 * operand stack (used, among other things, for boolean variables).  To branch,
 * it runs that code and then tests the pushed value.
 *
 * `evalUnderCode` is a by-name parameter, so it is re-evaluated on every use.
 */
private class CondValueExpr(evalUnderCode: => CodeBuilder, val typeOf: Type) extends ConditionalExpr {
  override def evalUnder = evalUnderCode
  
  def branch(t: Target, f: Target) = t match {
    case Jump(onTrue) =>
      f match {
        case Jump(onFalse) => evalUnderCode |+| Branch(IsTrue, onTrue) |+| Goto(onFalse)
        case Fall          => evalUnderCode |+| Branch(IsTrue, onTrue)
      }
    case Fall =>
      f match {
        case Jump(onFalse) => evalUnderCode |+| BranchNot(IsTrue, onFalse)
        case Fall          => throw new AssertionError("impl error: no reason why both branches should be Fall")
      }
  }
}

/**
 * The logical negation of a conditional expression.  Negation costs no
 * instructions: branching on `!b` is branching on `b` with the true and
 * false targets exchanged.
 */
private class Not(b: ConditionalExpr, val typeOf: Type) extends ConditionalExpr {
  def branch(trueBr: Target, falseBr: Target): CodeBuilder =
    b.branch(trueBr = falseBr, falseBr = trueBr)
}

/**
 * A conditional expression corresponding to the logical conjunction of the
 * given conditional expressions.  Evaluation short-circuits: if `b1` is false,
 * control leaves for the false target without testing `b2`.
 */
private class And(b1: ConditionalExpr, b2: ConditionalExpr, val typeOf: Type) extends ConditionalExpr {
  def branch(trueBr: Target, falseBr: Target): CodeBuilder = {
    val g    = new LabelGroup
    val btwn = new Label("between and", g)
    val end  = new Label("end and", g)
    
    // b1 falls through into b2 when true.  When false, the whole conjunction is
    // false, so b1 goes to falseBr.  If falseBr is itself Fall ("continue after
    // this expression"), b1 cannot fall through, so it jumps to `end` instead.
    // `end` must therefore be placed after b2's code.  Without Lbl(end), that
    // jump would target a label that never appears in the generated code.
    b1.branch(Fall, falseBr.replaceFall(end)) |+|
    Lbl(btwn) |+|
    b2.branch(trueBr, falseBr) |+|
    Lbl(end)
  }
}

/**
 * A conditional expression corresponding to the logical disjunction of the
 * given conditional expressions.  Evaluation short-circuits: if `b1` is true,
 * control leaves for the true target without testing `b2`.
 */
private class Or(b1: ConditionalExpr, b2: ConditionalExpr, val typeOf: Type) extends ConditionalExpr {
  def branch(trueBr: Target, falseBr: Target): CodeBuilder = {
    val g    = new LabelGroup
    val btwn = new Label("between or", g)
    val end  = new Label("end or", g)
    
    // b1 falls through into b2 when false.  When true, the whole disjunction is
    // true, so b1 goes to trueBr.  If trueBr is itself Fall ("continue after
    // this expression"), b1 cannot fall through, so it jumps to `end` instead.
    // `end` must therefore be placed after b2's code.  Without Lbl(end), that
    // jump would target a label that never appears in the generated code.
    b1.branch(trueBr.replaceFall(end), Fall) |+|
    Lbl(btwn) |+|
    b2.branch(trueBr, falseBr) |+|
    Lbl(end)
  }
}

/**
 * A conditional expression that compares two expressions, for equality or for
 * ordering, according to the given `Comparison`.  It pushes the underlying
 * values of `e1` and then `e2`, then branches on `comp`.  Its type is always
 * the universe's `bool`.
 * 
 * @todo add support for comparing strings for order
 * @todo remove support for comparing arrays and structs
 * @todo restrict comparability of slices to comparison with `nil`
 */
private sealed abstract class CompExpr(comp: Comparison) extends ConditionalExpr {
  val typeOf = scope.UniverseScope.bool
  
  /** Left-hand operand. */
  protected val e1: Expr
  /** Right-hand operand. */
  protected val e2: Expr
  
  /** Code that pushes the underlying value of `e1`, then that of `e2`. */
  protected def stackingCode = e1.evalUnder |+| e2.evalUnder
  
  private[expr] def branch(trueBr: Target, falseBr: Target): CodeBuilder = (trueBr, falseBr) match {
    case (Fall, Fall) =>
      throw new AssertionError("impl error: no reason why both branches should be Fall")
    case (Fall, Jump(onFalse)) =>
      stackingCode |+| BranchNot(comp, onFalse)
    case (Jump(onTrue), Fall) =>
      stackingCode |+| Branch(comp, onTrue)
    case (Jump(onTrue), Jump(onFalse)) =>
      stackingCode |+| Branch(comp, onTrue) |+| Goto(onFalse)
  }
}

// Comparisons of object operands (ObjEq / ObjNe).
private case class ObjEquals(e1: Expr, e2: Expr) extends CompExpr(comp = ObjEq)
private case class ObjNotEquals(e1: Expr, e2: Expr) extends CompExpr(comp = ObjNe)

// Comparisons of boolean operands (BoolEq / BoolNe).
private case class BoolEquals(e1: Expr, e2: Expr) extends CompExpr(comp = BoolEq)
private case class BoolNotEquals(e1: Expr, e2: Expr) extends CompExpr(comp = BoolNe)

// Numeric comparisons; `numT` selects the arithmetic type the comparison operates on.
private case class NumEquals(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumEq(numT))
private case class NumNotEquals(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumNe(numT))
private case class LessThan(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumLt(numT))
private case class GreaterThan(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumGt(numT))
private case class LessEquals(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumLeq(numT))
private case class GreaterEquals(e1: Expr, e2: Expr, numT: Arith) extends CompExpr(comp = NumGeq(numT))

